# CUDA discovery and dependency management

using Pkg, Pkg.Artifacts
using Libdl


## global state

# version of the CUDA toolkit in use; assigned by `use_artifact_cuda` / `use_local_cuda`
# from the output of `parse_toolkit_version`, and left undefined until one of them ran
const __version = Ref{VersionNumber}()

"""
    version()

Return the `VersionNumber` of the CUDA toolkit that was selected during discovery.
"""
function version()
    return @after_init(__version[])
end

"""
    release()

Return the CUDA release, i.e. only the major and minor components of the toolkit version
reported by [`version`](@ref).
"""
function release()
    return @after_init(VersionNumber(__version[].major, __version[].minor))
end

# paths to the toolkit components, assigned by `use_artifact_cuda` / `use_local_cuda`

# required components: discovery fails when any of these cannot be located
const __nvdisasm = Ref{String}()
const __libdevice = Ref{String}()
const __libcudadevrt = Ref{String}()
# optional components: may hold `nothing` (see `has_cupti` / `has_nvtx`)
const __libcupti = Ref{Union{Nothing,String}}()
const __libnvtx = Ref{Union{Nothing,String}}()

# path to the `nvdisasm` binary recorded during discovery
function nvdisasm()
    return @after_init(__nvdisasm[])
end

# path to the libdevice file recorded during discovery
function libdevice()
    return @after_init(__libdevice[])
end

# path to the `cudadevrt` static library recorded during discovery
function libcudadevrt()
    return @after_init(__libcudadevrt[])
end
# Return the path to the CUPTI library recorded during discovery.
#
# CUPTI is optional: when no library was recorded (`has_cupti()` is false) the assertion
# below fails with an explanatory message instead of returning `nothing`.
function libcupti()
    @after_init begin
        # guard against handing out `nothing` to callers that expect a library path
        @assert has_cupti() "This functionality is unavailable as CUPTI is missing."
        __libcupti[]
    end
end
# Return the path to the NVTX library recorded during discovery.
#
# NVTX is optional: when no library was recorded (`has_nvtx()` is false) the assertion
# below fails with an explanatory message instead of returning `nothing`.
function libnvtx()
    @after_init begin
        # guard against handing out `nothing` to callers that expect a library path
        @assert has_nvtx() "This functionality is unavailable as NVTX is missing."
        __libnvtx[]
    end
end

export has_cupti, has_nvtx

"""
    has_cupti()

Return `true` when a CUPTI library was recorded during discovery, `false` otherwise.
"""
function has_cupti()
    recorded = @after_init(__libcupti[])
    return recorded !== nothing
end

"""
    has_nvtx()

Return `true` when an NVTX library was recorded during discovery, `false` otherwise.
"""
function has_nvtx()
    recorded = @after_init(__libnvtx[])
    return recorded !== nothing
end


## discovery

# NOTE: we don't use autogenerated JLLs, because we have multiple artifacts and need to
#       decide at run time (i.e. not via package dependencies) which one to use.
#
# Maps a CUDA release to a zero-argument function wrapping the matching `artifact"..."`
# lookup. Wrapping the lookup in a function defers it until the entry is actually called,
# so only the release that gets selected is resolved.
const cuda_artifacts = Dict(
    v"10.2" => ()->artifact"CUDA10.2",
    v"10.1" => ()->artifact"CUDA10.1",
    v"10.0" => ()->artifact"CUDA10.0",
    v"9.2"  => ()->artifact"CUDA9.2",
    v"9.0"  => ()->artifact"CUDA9.0",
)

# Try to use CUDA from an artifact.
#
# Candidate releases are taken from `cuda_artifacts`: when `ENV["JULIA_CUDA_VERSION"]` is
# set only that exact release qualifies, otherwise every release that is not newer than
# `CUDAdrv.release()` does. Candidates are tried from newest to oldest, and the first one
# whose lookup succeeds is used.
#
# On success the toolkit globals (`__nvdisasm`, `__version`, `__libcupti`, `__libnvtx`,
# `__libcudadevrt` and `__libdevice`) are populated and `true` is returned. Returns `false`
# when none of the candidates could be obtained. An assertion fails when the selected
# artifact does not contain one of the expected files.
function use_artifact_cuda()
    @debug "Trying to use artifacts..."

    # select compatible artifacts
    # NOTE: this uses the non-mutating `filter` on purpose: filtering `cuda_artifacts` in
    #       place would permanently drop entries from the global table, so that a later
    #       call could only ever pick from what survived an earlier one.
    candidates = if haskey(ENV, "JULIA_CUDA_VERSION")
        wanted_version = VersionNumber(ENV["JULIA_CUDA_VERSION"])
        filter(entry -> first(entry) == wanted_version, cuda_artifacts)
    else
        driver_version = CUDAdrv.release()
        filter(entry -> first(entry) <= driver_version, cuda_artifacts)
    end

    # download and install, preferring the most recent release
    selected = nothing
    for candidate in sort(collect(keys(candidates)); rev=true)
        try
            selected = (release=candidate, dir=candidates[candidate]())
            break
        catch err
            # best effort: a failing artifact only means we move on to the next one,
            # but leave a trace so that the failure can be diagnosed
            @debug "Could not use the CUDA $candidate artifact" exception=(err, catch_backtrace())
        end
    end
    if selected === nothing
        @debug "Could not find a compatible artifact."
        return false
    end

    # bind once, so the helpers below do not capture the reassigned `selected`
    artifact_dir = selected.dir
    artifact_release = selected.release

    # utilities to look up stuff in the artifact (at known locations, so not using CUDAapi)
    get_binary(name) = joinpath(artifact_dir, "bin", Sys.iswindows() ? "$name.exe" : name)
    function get_library(name)
        filename = if Sys.iswindows()
            "$name.dll"
        elseif Sys.isapple()
            "lib$name.dylib"
        else
            "lib$name.so"
        end
        # on Windows, shared libraries live next to the binaries
        joinpath(artifact_dir, Sys.iswindows() ? "bin" : "lib", filename)
    end
    get_static_library(name) = joinpath(artifact_dir, "lib", Sys.iswindows() ? "$name.lib" : "lib$name.a")
    get_file(path) = joinpath(artifact_dir, path)

    __nvdisasm[] = get_binary("nvdisasm")
    @assert isfile(__nvdisasm[])
    __version[] = parse_toolkit_version(__nvdisasm[])

    # Windows libraries are tagged with the CUDA release
    long = "$(artifact_release.major)$(artifact_release.minor)"

    __libcupti[] = get_library(Sys.iswindows() ? "cupti64_$long" : "cupti")
    @assert isfile(__libcupti[])
    __libnvtx[] = get_library(Sys.iswindows() ? "nvToolsExt64_1" : "nvToolsExt")
    @assert isfile(__libnvtx[])

    __libcudadevrt[] = get_static_library("cudadevrt")
    @assert isfile(__libcudadevrt[])
    __libdevice[] = get_file(joinpath("share", "libdevice", "libdevice.10.bc"))
    @assert isfile(__libdevice[])

    @debug "Using CUDA $(__version[]) from an artifact at $(artifact_dir)"
    return true
end

# Try to use CUDA from a local installation.
#
# Looks for the toolkit with `find_toolkit` and records the location of its components in
# the toolkit globals. `nvdisasm`, `libcudadevrt` and `libdevice` are required: as soon as
# one of them cannot be found, `false` is returned (globals assigned up to that point keep
# their new values). The CUPTI and NVTX lookups are stored as returned, without a check.
# Returns `true` when all required components were found.
function use_local_cuda()
    @debug "Trying to use local installation..."

    toolkit_dirs = find_toolkit()

    # the disassembler is required, and also tells us the toolkit version
    nvdisasm_path = find_cuda_binary("nvdisasm", toolkit_dirs)
    if nvdisasm_path === nothing
        @debug "Could not find nvdisasm"
        return false
    end
    __nvdisasm[] = nvdisasm_path

    toolkit_version = parse_toolkit_version(__nvdisasm[])
    __version[] = toolkit_version

    # also search each toolkit's `extras/CUPTI` directory, when it exists
    extras_dirs = filter(isdir, map(dir -> joinpath(dir, "extras", "CUPTI"), toolkit_dirs))
    __libcupti[] = find_cuda_library("cupti", [toolkit_dirs; extras_dirs], [toolkit_version])
    __libnvtx[] = find_cuda_library("nvtx", toolkit_dirs, [v"1"])

    devrt_path = find_libcudadevrt(toolkit_dirs)
    if devrt_path === nothing
        @debug "Could not find libcudadevrt"
        return false
    end
    __libcudadevrt[] = devrt_path

    libdevice_path = find_libdevice(toolkit_dirs)
    if libdevice_path === nothing
        @debug "Could not find libdevice"
        return false
    end
    __libdevice[] = libdevice_path

    @debug "Found local CUDA $(toolkit_version) at $(join(toolkit_dirs, ", "))"
    return true
end

# Locate a CUDA toolkit and record its components in the toolkit globals.
#
# Artifacts are tried first, unless `ENV["JULIA_CUDA_USE_BINARYBUILDER"]` parses to `false`
# (it defaults to "true"); a local installation serves as the fallback. Returns whether a
# toolkit was found. With `show_reason` set, a failure is additionally reported as an
# error-level log message.
function __configure_dependencies__(show_reason::Bool)
    artifacts_enabled = parse(Bool, get(ENV, "JULIA_CUDA_USE_BINARYBUILDER", "true"))

    # short-circuiting keeps the order: artifacts (if enabled) first, then local
    found = (artifacts_enabled && use_artifact_cuda()) || use_local_cuda()

    if !found && show_reason
        @error "Could not find a suitable CUDA installation"
    end

    return found
end
